import pysam
from Bio.Seq import Seq
import random
import pandas as pd
import os

def create_test_primers(output_path='test_primers.tsv'):
    """Write the fixed test primer table to a TSV file and return it.

    Args:
        output_path: Destination of the tab-separated file. Defaults to
            'test_primers.tsv' in the current working directory.

    Returns:
        A dict mapping each column name ('Name', 'Forward', 'Reverse',
        'Size') to a list with one entry per primer pair.
    """
    primers = {
        'Name': ['Primer1', 'Primer2', 'Primer3'],
        'Forward': [
            'ACGTACGAGATAGATAGATATGAGAATTACGT',
            'GCTAGCTAACTAGATTTCGTAGATCGAGCTA',
            'TGCATCGATGTGATACACACGCATGCA',
        ],
        'Reverse': [
            'TGCATGGTGATGATCACATATAGCTAGATGCA',
            'CGTAGCTCGTAGCTAGCTAGCTCGTAGCTAAGCTA',
            'ACGTACGTGATGCCGCGCTAGCTAGTACGT',
        ],
        'Size': [200, 250, 400],  # Expected amplicon sizes
    }

    # One row per primer pair; the DataFrame index is not written.
    pd.DataFrame(primers).to_csv(output_path, sep='\t', index=False)
    return primers

def generate_perfect_read(forward_primer, reverse_primer, size):
    """Build a read with both primers in the expected orientation.

    The read is the forward primer, random A/C/G/T filler, then the reverse
    complement of the reverse primer. The filler length is chosen so that the
    total length equals `size` whenever `size` is at least the combined
    primer length.
    """
    filler_len = size - (len(forward_primer) + len(reverse_primer))
    filler = ''.join([random.choice('ACGT') for _ in range(filler_len)])
    # The reverse primer is stored 5'->3', so it appears reverse-complemented
    # at the 3' end of the read.
    tail = str(Seq(reverse_primer).reverse_complement())
    return forward_primer + filler + tail

def generate_wrong_size_read(forward_primer, reverse_primer, target_size,
                             oversize_fraction=0.2):
    """Generate a read with correct primers but the wrong amplicon size.

    Args:
        forward_primer: Sequence placed at the start of the read.
        reverse_primer: Sequence whose reverse complement ends the read.
        target_size: The expected (correct) amplicon size.
        oversize_fraction: Fraction of `target_size` added to the read
            length. The default 0.2 makes the read 20% larger than expected.

    Returns:
        The read as a string: forward primer, random A/C/G/T filler, then the
        reverse complement of the reverse primer.
    """
    # int() truncates the size delta toward zero.
    wrong_size = target_size + int(target_size * oversize_fraction)
    middle_length = wrong_size - len(forward_primer) - len(reverse_primer)
    middle = ''.join(random.choice('ACGT') for _ in range(middle_length))
    return forward_primer + middle + str(Seq(reverse_primer).reverse_complement())

def generate_incorrect_orientation(forward_primer, reverse_primer, size):
    """Build a read whose primers sit in the wrong order.

    The reverse primer (as given, not reverse-complemented) leads the read and
    the forward primer ends it, with random A/C/G/T filler in between.
    """
    filler_len = size - (len(forward_primer) + len(reverse_primer))
    filler_bases = [random.choice('ACGT') for _ in range(filler_len)]
    # Swapped layout: reverse primer first, forward primer last.
    return reverse_primer + ''.join(filler_bases) + forward_primer

def _random_bases(length):
    """Return `length` random characters drawn from 'ACGT'."""
    return ''.join(random.choice('ACGT') for _ in range(length))


def _write_unmapped_read(bam, name, sequence):
    """Write one unmapped record with the given name and sequence to `bam`."""
    segment = pysam.AlignedSegment()
    segment.query_name = name
    segment.query_sequence = sequence
    segment.flag = 4  # 0x4 = read unmapped
    bam.write(segment)


def create_test_bam(primers=None):
    """Create 'test_data.bam' containing synthetic unmapped reads.

    For every primer pair the following reads are written, in this order:
    27 perfect reads, 27 wrong-size reads, 40 reads with swapped primer
    orientation and 167 reads carrying only the forward primer. These are
    followed by 120 reads that combine the first forward primer with the
    second reverse primer, and 100 reads with no primer at all.

    With the three-pair table from `create_test_primers` this gives
    3 * 261 + 120 + 100 = 1003 reads (nominally "about 1000").

    Args:
        primers: Dict with equal-length 'Name', 'Forward', 'Reverse' and
            'Size' lists, with at least two primer pairs. When omitted, a
            module-level `primers` variable is used if one exists (this is
            what the script entry point sets up); otherwise the table is
            built with `create_test_primers()`, which also (re)writes its
            TSV file.

    Returns:
        The number of reads written.
    """
    if primers is None:
        # Backward compatibility: earlier versions read a module-level
        # `primers` global that only exists when the file runs as a script.
        primers = globals().get('primers')
    if primers is None:
        primers = create_test_primers()

    header = {'HD': {'VN': '1.0'},
              'SQ': [{'LN': 1000, 'SN': 'test_ref'}]}
    unsorted_path = 'test_reads.bam'
    read_count = 0

    try:
        with pysam.AlignmentFile(unsorted_path, 'wb', header=header) as outf:
            for name, fwd, rev, size in zip(primers['Name'],
                                            primers['Forward'],
                                            primers['Reverse'],
                                            primers['Size']):
                # 1. Perfect matches (27 per primer pair)
                for i in range(27):
                    _write_unmapped_read(
                        outf, f"{name}_perfect_{i}",
                        generate_perfect_read(fwd, rev, size))
                    read_count += 1

                # 2. Wrong size matches (27 per primer pair)
                for i in range(27):
                    _write_unmapped_read(
                        outf, f"{name}_wrong_size_{i}",
                        generate_wrong_size_read(fwd, rev, size))
                    read_count += 1

                # 3. Incorrect orientation (40 per primer pair)
                for i in range(40):
                    _write_unmapped_read(
                        outf, f"{name}_incorrect_orient_{i}",
                        generate_incorrect_orientation(fwd, rev, size))
                    read_count += 1

                # 4. Single-end reads (167 per primer pair): forward primer
                #    followed by 100 random bases. No reverse-only reads are
                #    generated.
                for i in range(167):
                    _write_unmapped_read(
                        outf, f"{name}_fwd_only_{i}",
                        fwd + _random_bases(100))
                    read_count += 1

            # 5. Mismatched primer pairs (120 reads): forward primer of the
            #    first pair with the reverse primer of the second pair.
            mismatched_fwd = primers['Forward'][0]
            mismatched_rev_rc = str(Seq(primers['Reverse'][1]).reverse_complement())
            for i in range(120):
                _write_unmapped_read(
                    outf, f"mismatched_pair_{i}",
                    mismatched_fwd + _random_bases(200) + mismatched_rev_rc)
                read_count += 1

            # 6. No primers (100 reads of 100 random bases)
            for i in range(100):
                _write_unmapped_read(outf, f"no_primers_{i}", _random_bases(100))
                read_count += 1

        # Sort into the final file. No index is created.
        pysam.sort('-o', 'test_data.bam', unsorted_path)
    finally:
        # Remove the intermediate file even when writing or sorting failed.
        if os.path.exists(unsorted_path):
            os.remove(unsorted_path)

    print(f"Total reads created: {read_count}")
    return read_count

if __name__ == "__main__":
    # Script entry point: write the primer TSV first, then build the BAM.
    # The result is bound to the module-level name `primers` on purpose:
    # create_test_bam() is called without arguments and looks the table up
    # under that global name, so neither the name nor the call order should
    # be changed here.
    primers = create_test_primers()
    create_test_bam()